import pymysql, yaml, json, os
from jinja2 import FileSystemLoader, Environment
import argparse

# Command-line interface. Every option is optional: argparse stores None for
# an option that is not given, and the loaders below apply their defaults.
parser = argparse.ArgumentParser(description="代码生成器")
parser.add_argument("-c", "--config", dest="config", help="配置文件")  # main config YAML
parser.add_argument("-d", "--dataType", dest="dataType", help="数据类型转换配置文件")  # type-mapping YAML
parser.add_argument("-t", "--tableName", dest="tableName", help="要代码生成的表")  # table LIKE pattern
parser.add_argument("-m", "--mode", dest="mode", help="1->生成代码,2->查看元数据")  # "2" = dump metadata
args = parser.parse_args()


def loadConfig(configPath=None):
    """
    Load the generator configuration from a YAML file.

    :param configPath: path of the YAML file. When omitted, the ``-c/--config``
        command-line value is used, falling back to ``config/app.yml``.
    :return: the parsed YAML document
    """
    if configPath is None:
        configPath = args.config
    if configPath is None:
        configPath = "config/app.yml"
    with open(configPath, 'r', encoding='utf-8') as f:
        # safe_load builds only plain YAML data (mappings, lists, scalars),
        # which is all this script reads from the config; it refuses
        # Python-object tags that FullLoader would still construct.
        return yaml.safe_load(f)


config = loadConfig()


def loadTypeConfig(dataTypePath=None):
    """
    Load the SQL-to-Java type mapping configuration from a YAML file.

    :param dataTypePath: path of the YAML file. When omitted, the
        ``-d/--dataType`` command-line value is used, falling back to
        ``config/dataType.yml``.
    :return: the parsed YAML document (tranferColumn reads its 'dataType' mapping)
    """
    if dataTypePath is None:
        dataTypePath = args.dataType
    if dataTypePath is None:
        dataTypePath = "config/dataType.yml"
    with open(dataTypePath, 'r', encoding='utf-8') as f:
        # safe_load builds only plain YAML data (mappings, lists, scalars).
        return yaml.safe_load(f)


typeConfig = loadTypeConfig()


def getAllTables(dbName):
    """
    List every table of the schema ``dbName``.

    :param dbName: schema (database) name
    :return: list of row dicts with keys table_catalog, table_schema,
             table_name and table_type
    """
    # Escaping keeps a quote inside the name from breaking (or injecting
    # into) the SQL text.
    from pymysql.converters import escape_string

    # Explicit lower-case aliases: MySQL 8.0 labels unaliased
    # information_schema columns in upper case (TABLE_NAME, ...), while the
    # callers in this script index rows by lower-case keys.
    sql = f'''
        SELECT t.table_catalog AS table_catalog, t.table_schema AS table_schema,
        t.table_name AS table_name, t.table_type AS table_type
        FROM information_schema.TABLES t
        WHERE t.table_schema = '{escape_string(str(dbName))}'
    '''
    return execute(sql)


def getTable(dbName, tableName):
    """
    Look up the table ``tableName`` in the schema ``dbName``.

    :param dbName: schema (database) name
    :param tableName: exact table name
    :return: row dict with keys table_catalog, table_schema, table_name and
             table_type, or {} when no such table exists
    """
    # Escaping keeps a quote inside a name from breaking the SQL text.
    from pymysql.converters import escape_string

    # Lower-case aliases pin the result keys regardless of server version
    # (MySQL 8.0 labels unaliased information_schema columns in upper case).
    sql = f'''
        SELECT t.table_catalog AS table_catalog, t.table_schema AS table_schema,
        t.table_name AS table_name, t.table_type AS table_type
        FROM information_schema.TABLES t
        WHERE t.table_schema = '{escape_string(str(dbName))}'
        AND t.table_name = '{escape_string(str(tableName))}'
    '''
    res = execute(sql)
    if res:
        return res[0]
    return {}


def getTablesLike(dbName, tableNameLike):
    """
    List the tables of the schema ``dbName`` whose name matches a SQL LIKE pattern.

    :param dbName: schema (database) name
    :param tableNameLike: LIKE pattern, e.g. ``sys_%``. ``%`` and ``_`` keep
                          their wildcard meaning.
    :return: list of row dicts with keys table_catalog, table_schema,
             table_name and table_type
    """
    # escape_string only escapes quotes, backslashes and control characters,
    # so the LIKE wildcards % and _ still work.
    from pymysql.converters import escape_string

    # Lower-case aliases pin the result keys regardless of server version
    # (MySQL 8.0 labels unaliased information_schema columns in upper case).
    sql = f'''
        SELECT t.table_catalog AS table_catalog, t.table_schema AS table_schema,
        t.table_name AS table_name, t.table_type AS table_type
        FROM information_schema.TABLES t
        WHERE t.table_schema = '{escape_string(str(dbName))}'
        AND t.table_name LIKE '{escape_string(str(tableNameLike))}'
    '''
    return execute(sql)


def getTableComment(tableName, dbName=None):
    """
    Fetch a table's name and comment via SHOW TABLE STATUS.

    :param tableName: table name
    :param dbName: schema to search. When None (the default, and the previous
                   behaviour), the statement runs against the connection's
                   default database, i.e. config['database']['dbName'].
    :return: {"name": ..., "comment": ...} for the first matching row, or {}
             when the table is not found
    """
    from pymysql.converters import escape_string

    sql = "show table status"
    if dbName is not None:
        # Quote the identifier; an embedded back-tick is doubled.
        sql += " from `" + str(dbName).replace("`", "``") + "`"
    sql += f" where NAME='{escape_string(str(tableName))}'"
    res = execute(sql)
    if not res:
        return {}
    row = res[0]
    # The label case of SHOW output may differ between servers, so each
    # spelling is tried in turn.
    return {
        "name": row.get("name", row.get("Name", row.get("NAME", ''))),
        "comment": row.get("comment", row.get("Comment", row.get("COMMENT", ''))),
    }


def getPrimaryColumns(dbName, tableName):
    """
    List the primary-key columns of a table.

    :param dbName: schema (database) name
    :param tableName: table name
    :return: list of row dicts with the single key column_name, in key order
    """
    from pymysql.converters import escape_string

    # Aliasing pins the key to lower case (MySQL 8.0 labels unaliased
    # information_schema columns in upper case). ORDER BY returns the columns
    # of a composite key in their declared order.
    sql = f'''
        SELECT k.column_name AS column_name
        FROM information_schema.table_constraints t
        JOIN information_schema.key_column_usage k
        USING (constraint_name, table_schema, table_name)
        WHERE t.constraint_type = 'PRIMARY KEY'
        AND t.table_schema = '{escape_string(str(dbName))}'
        AND t.table_name = '{escape_string(str(tableName))}'
        ORDER BY k.ordinal_position
    '''
    return execute(sql)


def getTableColumns(dbName, tableName):
    """
    Fetch the column metadata of one table from information_schema.columns.

    :param dbName: schema (database) name
    :param tableName: table name
    :return: list of row dicts, one per column, in table-definition order.
             Keys: table_schema, table_name, column_name, column_default,
             is_nullable, data_type, character_maximum_length,
             numeric_precision, numeric_scale, column_type, column_key,
             column_comment, extra
    """
    from pymysql.converters import escape_string

    # Aliases pin the result keys to lower case (MySQL 8.0 labels unaliased
    # information_schema columns in upper case). `extra` carries flags such as
    # auto_increment. ORDER BY keeps the generated fields in the same order as
    # the table definition.
    sql = f'''
        SELECT t.table_schema AS table_schema, t.table_name AS table_name,
        t.column_name AS column_name, t.column_default AS column_default,
        t.is_nullable AS is_nullable, t.data_type AS data_type,
        t.character_maximum_length AS character_maximum_length,
        t.numeric_precision AS numeric_precision, t.numeric_scale AS numeric_scale,
        t.column_type AS column_type, t.column_key AS column_key,
        t.column_comment AS column_comment, t.extra AS extra
        FROM information_schema.columns t
        WHERE t.table_schema = '{escape_string(str(dbName))}'
        AND t.table_name = '{escape_string(str(tableName))}'
        ORDER BY t.ordinal_position
    '''
    return execute(sql)


def execute(sql, params=None):
    """
    Run one SQL statement on a fresh connection and return all rows as dicts.

    Connection settings come from config['database'] (url, port, username,
    password, dbName). The connection is closed even when the query fails.

    :param sql: the statement to run
    :param params: optional query arguments passed to cursor.execute. When
                   None (the default), the SQL is sent exactly as given.
    :return: list of dicts mapping each result column label to its value;
             [] for statements that produce no result set
    """
    dbConfig = config['database']
    db = pymysql.connect(
        host=dbConfig['url'],
        port=dbConfig['port'],
        user=dbConfig['username'],
        password=dbConfig['password'],
        database=dbConfig['dbName'],
        charset="utf8",
    )
    try:
        with db.cursor() as cursor:
            cursor.execute(sql, params)
            rows = cursor.fetchall()
            # description is None for statements without a result set.
            if cursor.description is None:
                return []
            labels = [desc[0] for desc in cursor.description]
    finally:
        db.close()
    return [dict(zip(labels, row)) for row in rows]


def str2Hump(text, flag=False):
    """
    Convert an underscore_separated name to camel case.

    :param text: the name to convert, e.g. ``sys_user``
    :param flag: when True, produce UpperCamelCase (``SysUser``); otherwise
                 lowerCamelCase (``sysUser``). Default False.
    :return: the converted name
    """
    if flag == True:
        # capitalize() upper-cases the first letter and lower-cases the rest
        # of every segment; empty segments contribute nothing.
        return "".join(segment.capitalize() for segment in text.split("_"))
    words = [word for word in text.lower().split('_') if word]
    if not words:
        return ''
    head, *tail = words
    return head + "".join(word[0].upper() + word[1:] for word in tail)


def tranferTable(table):
    """
    Convert one information_schema.TABLES row into the table model used by the templates.

    :param table: row dict as returned by getTable/getTablesLike; must contain
                  'table_schema' and 'table_name'
    :return: dict with tableName, tableCameName (lowerCamel), entityName and
             className (UpperCamel), schema, remark (table comment), columns
             (see tranferColumn) and primaryKeys (the columns flagged as
             primary key)
    """
    dbName = table['table_schema']
    tableName = table['table_name']
    # getTableComment returns {} when SHOW TABLE STATUS finds no row; use an
    # empty remark instead of raising KeyError.
    # NOTE(review): getTableComment(tableName) queries the connection's
    # default database, which may differ from dbName.
    remark = getTableComment(tableName).get('comment', '')
    columns = [tranferColumn(tableColumn) for tableColumn in getTableColumns(dbName, tableName)]
    return {
        "tableName": tableName,
        "tableCameName": str2Hump(tableName),
        "entityName": str2Hump(tableName, flag=True),
        "className": str2Hump(tableName, flag=True),
        "schema": dbName,
        "remark": remark,
        "columns": columns,
        "primaryKeys": [column for column in columns if column['primaryKey']],
    }


def tranferColumn(tableColumn):
    """
    Convert one information_schema.columns row into the column model used by the templates.

    :param tableColumn: row dict from getTableColumns (lower-case keys).
                        'extra' is optional.
    :return: dict describing the column: names, accessor names, Java type,
             size/precision, nullability, key and auto-increment flags,
             default value and comment
    """
    # Character types report a maximum length and numeric types a precision.
    # When neither is set, the size stays 0.
    size = 0
    if tableColumn['character_maximum_length'] is not None:
        size = tableColumn['character_maximum_length']
    if tableColumn['numeric_precision'] is not None:
        size = tableColumn['numeric_precision']
    decimalDigits = 0
    if tableColumn['numeric_scale'] is not None:
        decimalDigits = tableColumn['numeric_scale']
    # Look up the upper-cased SQL type in the 'dataType' mapping of the type
    # config. Unmapped types fall back to String.
    javaTypeInfo = typeConfig['dataType'].get(tableColumn['data_type'].upper(), {
        "javaType": "String",
        "fullJavaType": "java.lang.String"
    })
    columnName = tableColumn['column_name']
    # 'extra' (e.g. "auto_increment") is only present when the query selected
    # it. Without it, the column is reported as not auto-incrementing.
    extra = str(tableColumn.get('extra') or '').lower()
    return {
        "tableName": tableColumn['table_name'],
        "columnName": columnName,
        "propertyName": str2Hump(columnName),
        "primaryKey": tableColumn['column_key'] == 'PRI',
        "foreignKey": False,
        "size": size,
        "decimalDigits": decimalDigits,
        "nullable": tableColumn['is_nullable'] == 'YES',
        "autoincrement": "auto_increment" in extra,
        "defaultValue": tableColumn['column_default'],
        "remark": tableColumn['column_comment'],
        "dataType": tableColumn['data_type'],
        "javaType": javaTypeInfo['javaType'],
        "fullJavaType": javaTypeInfo["fullJavaType"],
        "setterMethodName": "set" + str2Hump(columnName, flag=True),
        "getterMethodName": "get" + str2Hump(columnName, flag=True)
    }


def buildFromConfig(config):
    """
    Build table models for every table matching the patterns listed in the config.

    :param config: generator configuration; reads config['database']['dbName']
                   and config['tables'] (a list of {"tableName": <LIKE pattern>})
    :return: list of table models (see tranferTable). A table matched by
             several patterns appears only once, at its first match.
    """
    dbName = config['database']['dbName']
    res = []
    seen = set()
    # `tables:` left empty in YAML loads as None; treat it as no patterns.
    for table in config.get('tables') or []:
        # getTablesLike already returns the full TABLES row, so no second
        # per-table lookup is needed.
        for item in getTablesLike(dbName, table['tableName']):
            tableName = item['table_name']
            if tableName in seen:
                continue
            seen.add(tableName)
            res.append(tranferTable(item))
    return res


def buildByTableName(config, tableName):
    """
    Build table models for the tables matching the ``-t`` command-line pattern.

    :param config: generator configuration; reads config['database']['dbName']
    :param tableName: table name or SQL LIKE pattern (e.g. ``sys_%``)
    :return: list of table models (see tranferTable)
    """
    dbName = config['database']['dbName']
    # getTablesLike already returns the full TABLES row, so each row is
    # converted directly without a second per-table lookup.
    return [tranferTable(item) for item in getTablesLike(dbName, tableName)]


"""
字符串首字母转小写
"""


def uncap_first(str):
    return str[0:1].lower() + str[1:len(str)]


def genCode(tables, config):
    """
    Render every selected template for every table and write the output files.

    Templates are loaded from the ``templates`` directory. Each config entry
    in config['templates'] may use:
      - templateFile: template name
      - targetPath: Jinja2 string; its dots become directory separators
      - targetFileName: Jinja2 string
      - selected: false skips the entry
      - covered: true overwrites an existing file
      - encoding: output encoding, default utf-8
    targetPath and targetFileName are rendered with the whole config plus
    ``table``.

    :param tables: table models produced by tranferTable
    :param config: generator configuration; reads 'templates' and 'targetProject'
    :return: 1 (kept for backward compatibility)
    """
    env = Environment(loader=FileSystemLoader('templates'))
    env.filters['uncap_first'] = uncap_first
    templates = config['templates']
    for table in tables:
        templateData = {**config, "table": table}
        for item in templates:
            # Skip only this template. The previous `break` silently dropped
            # every template listed after an unselected one.
            if not item.get('selected', True):
                continue
            template = env.get_template(item['templateFile'])
            # Render targetProject and targetPath separately so that only the
            # package-style targetPath has its dots turned into directories.
            # A targetProject such as "../project/" is kept intact.
            project = env.from_string(config['targetProject']).render(templateData)
            subPath = env.from_string(item['targetPath']).render(templateData).replace(".", "/")
            path = project + subPath
            if path:
                os.makedirs(path, exist_ok=True)
            targetFileName = env.from_string(item['targetFileName']).render(templateData)
            dist = os.path.join(path, targetFileName)

            existed = os.path.exists(dist)
            if existed and not item.get('covered'):
                print(f"{dist}文件已存在，不覆盖")
                continue
            # Render before opening: open(..., 'w') truncates the file, so a
            # template error would otherwise wipe an existing file.
            content = template.render(templateData)
            with open(dist, 'w', encoding=item.get('encoding', 'utf-8')) as f:
                f.write(content)
            action = "覆盖" if existed else "新生成"
            print(f"{table['tableName']}表代码生成成功-{action}：{dist}")
    return 1


if __name__ == "__main__":
    # print(getAllTables('mysql'))
    # print(getTable('mysql','user'))
    # print(getTablesLike('mysql','help_%'))
    # print(getTableComment('user'))
    # print(getPrimaryColumns('mysql','user'))
    # print(getTableColumns('mysql','user'))
    tableName = args.tableName
    mode = args.mode
    if tableName != None:
        tables = buildByTableName(config, tableName)
    else:
        tables = buildFromConfig(config)
    if mode == "2":
        print(json.dumps(tables, indent=2, ensure_ascii=False))
    else:
        genCode(tables, config)
